{
  "nbformat": 4,
  "nbformat_minor": 0,
  "metadata": {
    "colab": {
      "name": "Waymo Open Dataset 3D Semantic Segmentation Tutorial.ipynb",
      "provenance": [
        {
          "file_id": "tutorial_3d_semseg.ipynb",
          "timestamp": 1649874845881
        },
        {
          "file_id": "tutorial.ipynb",
          "timestamp": 1644287712198
        }
      ],
      "private_outputs": true,
      "collapsed_sections": [],
      "toc_visible": true,
      "last_runtime": {
        "build_target": "",
        "kind": "local"
      }
    },
    "kernelspec": {
      "name": "python3",
      "display_name": "Python 3"
    }
  },
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "-pVhOfzLx9us"
      },
      "source": [
        "# Waymo Open Dataset 3D Semantic Segmentation Tutorial\n",
        "\n",
        "- Website: https://waymo.com/open\n",
        "- GitHub: https://github.com/waymo-research/waymo-open-dataset\n",
        "\n",
        "This tutorial demonstrates how to decode and interpret the 3D semantic segmentation labels. Visit the [Waymo Open Dataset Website](https://waymo.com/open) to download the full dataset.\n",
        "\n",
        "To use, open this notebook in [Colab](https://colab.research.google.com).\n",
        "\n",
        "Uncheck the box \"Reset all runtimes before running\" if you run this colab directly from the remote kernel. Alternatively, you can make a copy before trying to run it by following \"File > Save copy in Drive ...\".\n",
        "\n"
      ]
    },
    {
      "cell_type": "markdown",
      "source": [
        "# Package Installation"
      ],
      "metadata": {
        "id": "1sPLur9kMaLh"
      }
    },
    {
      "cell_type": "markdown",
      "source": [
        "To install the required packages, please follow the instructions in [tutorial.ipynb](https://github.com/waymo-research/waymo-open-dataset/blob/master/tutorial/tutorial.ipynb)."
      ],
      "metadata": {
        "id": "iEsf_G5_MeS-"
      }
    },
    {
      "cell_type": "markdown",
      "source": [
        "# Imports and global definitions"
      ],
      "metadata": {
        "id": "rqs8_62VNc4T"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "# Data location. Please edit.\n",
        "\n",
        "# A tfrecord file whose records are serialized Frame protos, as downloaded\n",
        "# from the Waymo Open Dataset webpage.\n",
        "\n",
        "# Replace this path with your own tfrecords.\n",
        "FILENAME = '/content/waymo-od/tutorial/.../tfexample.tfrecord'"
      ],
      "metadata": {
        "id": "YuNAlbQpNkLa"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "import os\n",
        "import matplotlib.pyplot as plt\n",
        "import tensorflow.compat.v1 as tf\n",
        "import numpy as np\n",
        "\n",
        "tf.enable_eager_execution()\n",
        "\n",
        "from waymo_open_dataset.utils import  frame_utils\n",
        "from waymo_open_dataset import dataset_pb2 as open_dataset\n",
        "from waymo_open_dataset.protos import segmentation_metrics_pb2\n",
        "from waymo_open_dataset.protos import segmentation_submission_pb2"
      ],
      "metadata": {
        "id": "xCDNLdp9Ni8a"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "ibor0U9XBlX6"
      },
      "source": [
        "# Read 3D semantic segmentation labels from Frame proto\n",
        "\n",
        "Note that only a subset of the frames have 3D semantic segmentation labels."
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "# Scan the file for the first frame whose TOP lidar carries 3D semantic\n",
        "# segmentation labels; only a subset of the frames is labeled.\n",
        "dataset = tf.data.TFRecordDataset(FILENAME, compression_type='')\n",
        "for data in dataset:\n",
        "  frame = open_dataset.Frame()\n",
        "  frame.ParseFromString(bytearray(data.numpy()))\n",
        "  # Look the TOP lidar up by name instead of assuming it is lasers[0]; the\n",
        "  # following cells index the labels by open_dataset.LaserName.TOP.\n",
        "  if any(laser.name == open_dataset.LaserName.TOP and\n",
        "         laser.ri_return1.segmentation_label_compressed\n",
        "         for laser in frame.lasers):\n",
        "    break\n",
        "else:\n",
        "  # Without this, `frame` would silently be the last (unlabeled) frame, or be\n",
        "  # undefined for an empty file, and the following cells would fail obscurely.\n",
        "  raise ValueError('No frame with 3D semseg labels found in %s' % FILENAME)"
      ],
      "metadata": {
        "id": "O41R3lljM9Ym"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "# Name and stats of the run segment the selected frame belongs to.\n",
        "print(frame.context.name, frame.context.stats, sep='\\n')"
      ],
      "metadata": {
        "id": "opFz4B9JXC7p"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "wHK95_JBUXUx"
      },
      "source": [
        "# Parse the frame into its range images, camera projections, segmentation\n",
        "# labels and the range image pose of the top lidar.\n",
        "(range_images,\n",
        " camera_projections,\n",
        " segmentation_labels,\n",
        " range_image_top_pose) = (\n",
        "     frame_utils.parse_range_image_and_camera_projection(frame))"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "# Dimensions of the TOP lidar label image for return index 0.\n",
        "print(\n",
        "    segmentation_labels[open_dataset.LaserName.TOP][0].shape.dims)"
      ],
      "metadata": {
        "id": "fgCDPt9zeV_k"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "source": [
        "# Visualize Segmentation Labels in Range Images"
      ],
      "metadata": {
        "id": "T4tCSnwFR3uh"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "def plot_range_image_helper(data, name, layout, vmin=0, vmax=1, cmap='gray'):\n",
        "  \"\"\"Plots one image in a subplot of the current figure.\n",
        "\n",
        "  Args:\n",
        "    data: image data to show\n",
        "    name: the image title\n",
        "    layout: plt subplot layout, unpacked into plt.subplot\n",
        "    vmin: lower bound of the color scale\n",
        "    vmax: upper bound of the color scale\n",
        "    cmap: color map\n",
        "  \"\"\"\n",
        "  plt.subplot(*layout)\n",
        "  plt.imshow(data, cmap=cmap, vmin=vmin, vmax=vmax)\n",
        "  plt.title(name)\n",
        "  plt.grid(False)\n",
        "  plt.axis('off')\n",
        "\n",
        "def get_semseg_label_image(laser_name, return_index):\n",
        "  \"\"\"Returns semseg label image given a laser name and its return index.\"\"\"\n",
        "  return segmentation_labels[laser_name][return_index]\n",
        "\n",
        "def show_semseg_label_image(semseg_label_image, layout_index_start=1):\n",
        "  \"\"\"Shows the instance id and semantic class images of a semseg label.\n",
        "\n",
        "  Args:\n",
        "    semseg_label_image: the semseg label data of type MatrixInt32.\n",
        "    layout_index_start: subplot index of the instance id image; the semantic\n",
        "      class image is drawn in the subplot right after it.\n",
        "  \"\"\"\n",
        "  label_tensor = tf.reshape(\n",
        "      tf.convert_to_tensor(semseg_label_image.data),\n",
        "      semseg_label_image.shape.dims)\n",
        "  # Channel 0 is shown as the instance id, channel 1 as the semantic class.\n",
        "  instance_id_image = label_tensor[..., 0]\n",
        "  semantic_class_image = label_tensor[..., 1]\n",
        "  plot_range_image_helper(instance_id_image.numpy(), 'instance id',\n",
        "                          [8, 1, layout_index_start],\n",
        "                          vmin=-1, vmax=200, cmap='Paired')\n",
        "  plot_range_image_helper(semantic_class_image.numpy(), 'semantic class',\n",
        "                          [8, 1, layout_index_start + 1],\n",
        "                          vmin=0, vmax=22, cmap='tab20')\n",
        "\n",
        "# Create the figure right before drawing into it.\n",
        "plt.figure(figsize=(64, 20))\n",
        "frame.lasers.sort(key=lambda laser: laser.name)\n",
        "show_semseg_label_image(get_semseg_label_image(open_dataset.LaserName.TOP, 0), 1)\n",
        "show_semseg_label_image(get_semseg_label_image(open_dataset.LaserName.TOP, 1), 3)"
      ],
      "metadata": {
        "id": "yRE_3QhMR7KQ"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "source": [
        "# Point Cloud Conversion and Visualization"
      ],
      "metadata": {
        "id": "C097jvXXR71D"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "def convert_range_image_to_point_cloud_labels(frame,\n",
        "                                              range_images,\n",
        "                                              segmentation_labels,\n",
        "                                              ri_index=0):\n",
        "  \"\"\"Gathers the segmentation labels of the valid range image pixels.\n",
        "\n",
        "  Lasers are visited in the order of their calibration names. A pixel counts\n",
        "  as a valid point when the first channel of its range image is positive.\n",
        "\n",
        "  Args:\n",
        "    frame: open dataset frame\n",
        "    range_images: A dict of {laser_name, [range_image_first_return,\n",
        "       range_image_second_return]}.\n",
        "    segmentation_labels: A dict of {laser_name, [range_image_first_return,\n",
        "       range_image_second_return]}.\n",
        "    ri_index: 0 for the first return, 1 for the second return.\n",
        "\n",
        "  Returns:\n",
        "    point_labels: {[N, 2]} list of the segmentation labels of the 3d lidar\n",
        "      points, one array per laser. All zeros for lasers that have no entry\n",
        "      in segmentation_labels.\n",
        "  \"\"\"\n",
        "  per_laser_labels = []\n",
        "  for calibration in sorted(\n",
        "      frame.context.laser_calibrations, key=lambda c: c.name):\n",
        "    laser_name = calibration.name\n",
        "    ri = range_images[laser_name][ri_index]\n",
        "    ri_tensor = tf.reshape(tf.convert_to_tensor(ri.data), ri.shape.dims)\n",
        "    valid_mask = ri_tensor[..., 0] > 0\n",
        "\n",
        "    if laser_name not in segmentation_labels:\n",
        "      # No labels for this laser: emit one zero label per valid point.\n",
        "      valid_count = tf.math.reduce_sum(tf.cast(valid_mask, tf.int32))\n",
        "      labels_at_points = tf.zeros([valid_count, 2], dtype=tf.int32)\n",
        "    else:\n",
        "      label_proto = segmentation_labels[laser_name][ri_index]\n",
        "      label_tensor = tf.reshape(\n",
        "          tf.convert_to_tensor(label_proto.data), label_proto.shape.dims)\n",
        "      labels_at_points = tf.gather_nd(label_tensor, tf.where(valid_mask))\n",
        "\n",
        "    per_laser_labels.append(labels_at_points.numpy())\n",
        "  return per_laser_labels"
      ],
      "metadata": {
        "id": "QI56Rt1BPuOW"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "# Point clouds and their camera projections for the default return index\n",
        "# (presumably the first return - confirm against frame_utils).\n",
        "points, cp_points = frame_utils.convert_range_image_to_point_cloud(\n",
        "    frame,\n",
        "    range_images,\n",
        "    camera_projections,\n",
        "    range_image_top_pose)\n",
        "# Same conversion for return index 1.\n",
        "points_ri2, cp_points_ri2 = frame_utils.convert_range_image_to_point_cloud(\n",
        "    frame,\n",
        "    range_images,\n",
        "    camera_projections,\n",
        "    range_image_top_pose,\n",
        "    ri_index=1)"
      ],
      "metadata": {
        "id": "PSAbxSW0RAQo"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "# Per-laser point labels for the first (ri_index=0) and second (ri_index=1)\n",
        "# return.\n",
        "point_labels = convert_range_image_to_point_cloud_labels(\n",
        "    frame, range_images, segmentation_labels, ri_index=0)\n",
        "point_labels_ri2 = convert_range_image_to_point_cloud_labels(\n",
        "    frame, range_images, segmentation_labels, ri_index=1)"
      ],
      "metadata": {
        "id": "jgUcnL9IRMZs"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "# Merge the per-laser arrays into one array per quantity and return\n",
        "# (np.concatenate joins along axis 0 by default).\n",
        "\n",
        "# 3d points in vehicle frame.\n",
        "points_all = np.concatenate(points)\n",
        "points_all_ri2 = np.concatenate(points_ri2)\n",
        "# point labels.\n",
        "point_labels_all = np.concatenate(point_labels)\n",
        "point_labels_all_ri2 = np.concatenate(point_labels_ri2)\n",
        "# camera projection corresponding to each point.\n",
        "cp_points_all = np.concatenate(cp_points)\n",
        "cp_points_all_ri2 = np.concatenate(cp_points_ri2)"
      ],
      "metadata": {
        "id": "FdYRcWdRRWmc"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "PCj3SDbuq9Nr"
      },
      "source": [
        "### Show colored point cloud\n",
        "Example of rendered point clouds (this tutorial does not have visualization capability)."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "K8VFnGnOq6cO"
      },
      "source": [
        "from IPython.display import Image, display\n",
        "display(Image('/content/waymo-od/tutorial/3d_semseg_points.png'))"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "source": [
        "# Create a dummy submission file for the validation set"
      ],
      "metadata": {
        "id": "37g8RIQhNMm7"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "import zlib\n",
        "\n",
        "def compress_array(array: np.ndarray, is_int32: bool = False):\n",
        "  \"\"\"Serializes a numpy array as a matrix proto and zlib-compresses it.\n",
        "\n",
        "  Args:\n",
        "    array: A numpy array.\n",
        "    is_int32: If true, use MatrixInt32, otherwise use MatrixFloat.\n",
        "\n",
        "  Returns:\n",
        "    The zlib-compressed bytes of the serialized MatrixInt32/MatrixFloat.\n",
        "  \"\"\"\n",
        "  matrix_cls = (\n",
        "      open_dataset.MatrixInt32 if is_int32 else open_dataset.MatrixFloat)\n",
        "  matrix = matrix_cls()\n",
        "  # The proto stores the shape and the flattened values separately.\n",
        "  matrix.shape.dims.extend(array.shape)\n",
        "  matrix.data.extend(array.ravel().tolist())\n",
        "  return zlib.compress(matrix.SerializeToString())\n",
        "\n",
        "def decompress_array(array_compressed: bytes, is_int32: bool = False):\n",
        "  \"\"\"Restores a numpy array from zlib-compressed MatrixFloat/Int32 bytes.\n",
        "\n",
        "  Args:\n",
        "    array_compressed: bytes.\n",
        "    is_int32: If true, use MatrixInt32, otherwise use MatrixFloat.\n",
        "\n",
        "  Returns:\n",
        "    The decompressed numpy array (int32 if is_int32, float32 otherwise),\n",
        "    reshaped to the dims stored in the proto.\n",
        "  \"\"\"\n",
        "  if is_int32:\n",
        "    matrix, dtype = open_dataset.MatrixInt32(), np.int32\n",
        "  else:\n",
        "    matrix, dtype = open_dataset.MatrixFloat(), np.float32\n",
        "  matrix.ParseFromString(zlib.decompress(array_compressed))\n",
        "  return np.array(matrix.data, dtype=dtype).reshape(matrix.shape.dims)"
      ],
      "metadata": {
        "id": "YaC7-BrlNQIz"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "TOP_LIDAR_ROW_NUM = 64\n",
        "TOP_LIDAR_COL_NUM = 2650"
      ],
      "metadata": {
        "id": "bZTjFN_CNlNx"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "def get_range_image_point_indexing(range_images, ri_index=0):\n",
        "  \"\"\"Get the indices of the valid points (of the TOP lidar) in the range image.\n",
        "\n",
        "  The order of the points match those from convert_range_image_to_point_cloud\n",
        "  and convert_range_image_to_point_cloud_labels.\n",
        "\n",
        "  The indices are derived from the shape of the range image itself, so the\n",
        "  function does not depend on a hard-coded image size.\n",
        "\n",
        "  Args:\n",
        "    range_images: A dict of {laser_name, [range_image_first_return,\n",
        "       range_image_second_return]}.\n",
        "    ri_index: 0 for the first return, 1 for the second return.\n",
        "\n",
        "  Returns:\n",
        "    points_indexing_top: (N, 2) col and row indices of the points in the\n",
        "      TOP lidar.\n",
        "  \"\"\"\n",
        "  range_image = range_images[open_dataset.LaserName.TOP][ri_index]\n",
        "  range_image_tensor = tf.reshape(\n",
        "      tf.convert_to_tensor(range_image.data), range_image.shape.dims)\n",
        "  # A pixel is a valid point when the first channel is positive.\n",
        "  range_image_mask = (range_image_tensor[..., 0] > 0).numpy()\n",
        "  # np.nonzero walks the mask in row-major order, which keeps the points in\n",
        "  # the same order as the tf.where based label conversion.\n",
        "  rows, cols = np.nonzero(range_image_mask)\n",
        "  # Column index first, row index second.\n",
        "  return np.stack([cols, rows], axis=-1)"
      ],
      "metadata": {
        "id": "5P9hToEvNUUJ"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "def dummy_semseg_for_one_frame(frame, dummy_class=14):\n",
        "  \"\"\"Builds a prediction that puts every valid TOP lidar point in one class.\n",
        "\n",
        "  Args:\n",
        "    frame: An Open Dataset Frame proto.\n",
        "    dummy_class: The class to assign to. Default is 14 (building).\n",
        "\n",
        "  Returns:\n",
        "    segmentation_frame: A SegmentationFrame proto.\n",
        "  \"\"\"\n",
        "  range_images, _, _, _ = (\n",
        "      frame_utils.parse_range_image_and_camera_projection(frame))\n",
        "\n",
        "  def _compressed_dummy_prediction(ri_index):\n",
        "    \"\"\"Returns the compressed dummy label image for one lidar return.\"\"\"\n",
        "    # (N, 2) col, row indices of the valid points.\n",
        "    col_row = get_range_image_point_indexing(range_images, ri_index=ri_index)\n",
        "    prediction = np.zeros(\n",
        "        (TOP_LIDAR_ROW_NUM, TOP_LIDAR_COL_NUM, 2), dtype=np.int32)\n",
        "    # Only channel 1 is written for the valid points; channel 0 stays 0.\n",
        "    prediction[col_row[:, 1], col_row[:, 0], 1] = dummy_class\n",
        "    return compress_array(prediction, is_int32=True)\n",
        "\n",
        "  # The predictions of both returns are stored on a TOP laser proto.\n",
        "  laser_semseg = open_dataset.Laser()\n",
        "  laser_semseg.name = open_dataset.LaserName.TOP\n",
        "  laser_semseg.ri_return1.segmentation_label_compressed = (\n",
        "      _compressed_dummy_prediction(0))\n",
        "  laser_semseg.ri_return2.segmentation_label_compressed = (\n",
        "      _compressed_dummy_prediction(1))\n",
        "\n",
        "  # Construct the SegmentationFrame proto.\n",
        "  segmentation_frame = segmentation_metrics_pb2.SegmentationFrame()\n",
        "  segmentation_frame.context_name = frame.context.name\n",
        "  segmentation_frame.frame_timestamp_micros = frame.timestamp_micros\n",
        "  segmentation_frame.segmentation_labels.append(laser_semseg)\n",
        "  return segmentation_frame"
      ],
      "metadata": {
        "id": "n0ijObegNgHc"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "# Create the dummy pred file for the validation set run segments.\n",
        "\n",
        "# Replace this path with the real path to the WOD validation set folder.\n",
        "folder_name = '/content/waymo-od/.../validation/'\n",
        "\n",
        "# Sort the listing so the frames are written in a reproducible order\n",
        "# (os.listdir returns the entries in arbitrary order).\n",
        "filenames = sorted(\n",
        "    os.path.join(folder_name, x) for x in os.listdir(folder_name)\n",
        "    if 'tfrecord' in x)\n",
        "assert len(filenames) == 202, (\n",
        "    'Expected 202 validation run segments, found %d' % len(filenames))\n",
        "\n",
        "segmentation_frame_list = segmentation_metrics_pb2.SegmentationFrameList()\n",
        "for idx, filename in enumerate(filenames):\n",
        "  if idx % 10 == 0:\n",
        "    print('Processing %d/%d run segments...' % (idx, len(filenames)))\n",
        "  dataset = tf.data.TFRecordDataset(filename, compression_type='')\n",
        "  for data in dataset:\n",
        "    frame = open_dataset.Frame()\n",
        "    frame.ParseFromString(bytearray(data.numpy()))\n",
        "    # Only frames whose TOP lidar carries semseg labels get a prediction; the\n",
        "    # laser is looked up by name instead of assuming it is lasers[0].\n",
        "    if any(laser.name == open_dataset.LaserName.TOP and\n",
        "           laser.ri_return1.segmentation_label_compressed\n",
        "           for laser in frame.lasers):\n",
        "      segmentation_frame = dummy_semseg_for_one_frame(frame)\n",
        "      segmentation_frame_list.frames.append(segmentation_frame)\n",
        "print('Total number of frames: ', len(segmentation_frame_list.frames))"
      ],
      "metadata": {
        "id": "W3iz5MazOhPy"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "# Create the submission file, which can be uploaded to the eval server.\n",
        "# All fields are passed to the proto constructor; the frame list given as\n",
        "# inference_results is copied into the new message.\n",
        "submission = segmentation_submission_pb2.SemanticSegmentationSubmission(\n",
        "    account_name='joe@gmail.com',\n",
        "    unique_method_name='JoeNet',\n",
        "    affiliation='Smith Inc.',\n",
        "    authors=['Joe Smith'],\n",
        "    description=\"A dummy method by Joe (val set).\",\n",
        "    method_link='NA',\n",
        "    sensor_type=1,\n",
        "    number_past_frames_exclude_current=2,\n",
        "    number_future_frames_exclude_current=0,\n",
        "    inference_results=segmentation_frame_list)"
      ],
      "metadata": {
        "id": "0_YlqK4RR8pR"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "output_filename = '/tmp/wod_semseg_val_set_dummy_pred_submission.bin'\n",
        "# The context manager closes the file even if serialization or the write\n",
        "# raises.\n",
        "with open(output_filename, 'wb') as f:\n",
        "  f.write(submission.SerializeToString())"
      ],
      "metadata": {
        "id": "hrCD59OkR_wK"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "source": [
        "# Create a dummy submission file for the testing set"
      ],
      "metadata": {
        "id": "GIjNwn3bSTec"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "# Create the dummy pred file for the testing set run segments.\n",
        "\n",
        "# Replace the paths with the real paths to the WOD testing set folders.\n",
        "folder_name1 = '/content/waymo-od/.../testing/'\n",
        "folder_name2 = '/content/waymo-od/.../testing_location/'\n",
        "# Sort each listing so the frames are written in a reproducible order\n",
        "# (os.listdir returns the entries in arbitrary order).\n",
        "filenames1 = sorted(\n",
        "    os.path.join(folder_name1, x) for x in os.listdir(folder_name1)\n",
        "    if 'tfrecord' in x)\n",
        "filenames2 = sorted(\n",
        "    os.path.join(folder_name2, x) for x in os.listdir(folder_name2)\n",
        "    if 'tfrecord' in x)\n",
        "filenames = filenames1 + filenames2\n",
        "print(len(filenames))\n",
        "assert len(filenames) == 150, (\n",
        "    'Expected 150 testing run segments, found %d' % len(filenames))\n",
        "\n",
        "# Replace this path with the real path. The file is under:\n",
        "# /waymo-open-dataset/tutorial/ in the github repo.\n",
        "# Each line of the file is the \"<context_name>, <timestamp_micros>\" of a frame\n",
        "# with semseg labels.\n",
        "testing_set_frame_file = '/path/3d_semseg_test_set_frames.txt'\n",
        "# Keep the entries as a set of (context_name, timestamp) string tuples. The\n",
        "# lookup in the loop uses a tuple, which never compares equal to the list\n",
        "# that str.split returns, and a set makes every lookup constant time.\n",
        "# Fields are stripped so whitespace after the comma does not break the match.\n",
        "with open(testing_set_frame_file, 'r') as frame_file:\n",
        "  context_name_timestamp_tuples = {\n",
        "      tuple(field.strip() for field in line.split(','))\n",
        "      for line in frame_file if line.strip()}\n",
        "\n",
        "segmentation_frame_list = segmentation_metrics_pb2.SegmentationFrameList()\n",
        "for idx, filename in enumerate(filenames):\n",
        "  if idx % 10 == 0:\n",
        "    print('Processing %d/%d run segments...' % (idx, len(filenames)))\n",
        "  dataset = tf.data.TFRecordDataset(filename, compression_type='')\n",
        "  for data in dataset:\n",
        "    frame = open_dataset.Frame()\n",
        "    frame.ParseFromString(bytearray(data.numpy()))\n",
        "    context_name = frame.context.name\n",
        "    timestamp = frame.timestamp_micros\n",
        "    if (context_name, str(timestamp)) in context_name_timestamp_tuples:\n",
        "      print(context_name, timestamp)\n",
        "      segmentation_frame = dummy_semseg_for_one_frame(frame)\n",
        "      segmentation_frame_list.frames.append(segmentation_frame)\n",
        "print('Total number of frames: ', len(segmentation_frame_list.frames))"
      ],
      "metadata": {
        "id": "4-fJxQHXSVVk"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "# Create the submission file, which can be uploaded to the eval server.\n",
        "# All fields are passed to the proto constructor; the frame list given as\n",
        "# inference_results is copied into the new message.\n",
        "submission = segmentation_submission_pb2.SemanticSegmentationSubmission(\n",
        "    account_name='joe@gmail.com',\n",
        "    unique_method_name='JoeNet',\n",
        "    affiliation='Smith Inc.',\n",
        "    authors=['Joe Smith'],\n",
        "    description=\"A dummy method by Joe (test set).\",\n",
        "    method_link='NA',\n",
        "    sensor_type=1,\n",
        "    number_past_frames_exclude_current=2,\n",
        "    number_future_frames_exclude_current=0,\n",
        "    inference_results=segmentation_frame_list)"
      ],
      "metadata": {
        "id": "30q-hzjQljYc"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "output_filename = '/tmp/wod_semseg_test_set_dummy_pred_submission.bin'\n",
        "# The context manager closes the file even if serialization or the write\n",
        "# raises.\n",
        "with open(output_filename, 'wb') as f:\n",
        "  f.write(submission.SerializeToString())"
      ],
      "metadata": {
        "id": "oStSwaE-Sv5Y"
      },
      "execution_count": null,
      "outputs": []
    }
  ]
}
